{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04_dropout_intro.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mlelarge/dataflowr/blob/master/Notebooks/04_deeper/04_dropout_intro_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "metadata": {
        "id": "H-pGQemIx72p",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# DropOut intuition"
      ]
    },
    {
      "metadata": {
        "id": "UbUz5Ocsx72s",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "import torch"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "D8WXRqxNx72w",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# 1. Dropout from scratch"
      ]
    },
    {
      "metadata": {
        "id": "n87efjJpx72x",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Let's code our own dropout function first."
      ]
    },
    {
      "metadata": {
        "id": "b5QmfUlJx72x",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def dropout(X, drop_prob):\n",
        "    assert 0 <= drop_prob <= 1\n",
        "    # In this case, all elements are dropped out\n",
        "    if drop_prob == 1:\n",
        "        return torch.zeros_like(X)\n",
        "    mask = torch.rand(X.shape).uniform_() > drop_prob\n",
        "    return mask.float() * X / (1.0-drop_prob)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "2YvMdJrtx72z",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "X = torch.arange(16).view(2,8).float()\n",
        "print(X)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "iTu1ZNvzx722",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "print(dropout(X, 0))\n",
        "print(dropout(X, 0.5))\n",
        "print(dropout(X, 1))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "kClbySxGx724",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# 2. DropOut on a toy dataset"
      ]
    },
    {
      "metadata": {
        "id": "p6Yxd33Xx725",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "N_SAMPLES = 20\n",
        "N_HIDDEN = 300"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "BbaEOLDLx727",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Let's generate some training data"
      ]
    },
    {
      "metadata": {
        "id": "aGacY3QHx728",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# training data\n",
        "x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)\n",
        "y = x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))\n",
        "\n",
        "# test data\n",
        "test_x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)\n",
        "test_y = test_x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "a6O7oAzTx72-",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# show data\n",
        "plt.scatter(x.data.numpy(), y.data.numpy(), c='green', s=50, alpha=0.5, label='train')\n",
        "plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='orange', s=50, alpha=0.5, label='test')\n",
        "plt.legend(loc='upper left')\n",
        "plt.ylim((-2.5, 2.5))\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "BCkx0YJhx73D",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "#### Exercise"
      ]
    },
    {
      "metadata": {
        "id": "egV6cqNex73E",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Create a network `net_simple` as `nn.Sequential` with the following structure: `Fully Connected N_HIDDEN -> ReLU -> Fully Connected N_HIDDEN -> ReLU -> Fully Connected 1`"
      ]
    },
    {
      "metadata": {
        "id": "NEA2ljVnx73F",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "net_simple = torch.nn.Sequential(\n",
        "#     TODO\n",
        ")\n",
        "\n",
        "print(net_simple)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "lyDCBZWqx73I",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "#### Exercise\n",
        "\n",
        "Now define the same architecture with `Dropout` layers in-between with $p=0.5$. Where will you place them?"
      ]
    },
    {
      "metadata": {
        "id": "PzduP4k8x73J",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "net_dropout = torch.nn.Sequential(\n",
        "#    TODO\n",
        ")\n",
        "\n",
        "print(net_dropout)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "txKmgBiDx73M",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "optimizer_simple = torch.optim.Adam(net_simple.parameters(), lr=0.01)\n",
        "optimizer_dropout = torch.optim.Adam(net_dropout.parameters(), lr=0.01)\n",
        "loss_fn = torch.nn.MSELoss()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "IttM4NmMx73P",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "for epoch in range(500):\n",
        "    pred_simple = net_simple(x)\n",
        "    pred_drop = net_dropout(x)\n",
        "    loss_simple = loss_fn(pred_simple, y)\n",
        "    loss_dropout = loss_fn(pred_drop, y)\n",
        "\n",
        "    optimizer_simple.zero_grad()\n",
        "    optimizer_dropout.zero_grad()\n",
        "    loss_simple.backward()\n",
        "    loss_dropout.backward()\n",
        "    optimizer_simple.step()\n",
        "    optimizer_dropout.step()\n",
        "\n",
        "    if epoch % 100 == 0:\n",
        "        # change to eval mode in order to fix drop out effect\n",
        "        net_simple.eval()\n",
        "        net_dropout.eval()  # parameters for dropout differ from train mode\n",
        "\n",
        "        # plotting\n",
        "        plt.cla()\n",
        "        test_pred_simple = net_simple(test_x)\n",
        "        test_pred_dropout = net_dropout(test_x)\n",
        "        plt.scatter(x.data.numpy(), y.data.numpy(), c='green', s=50, alpha=0.3, label='train')\n",
        "        plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='orange', s=50, alpha=0.3, label='test')\n",
        "        plt.plot(test_x.data.numpy(), test_pred_simple.data.numpy(), 'r-', lw=3, label='overfitting')\n",
        "        plt.plot(test_x.data.numpy(), test_pred_dropout.data.numpy(), 'b--', lw=3, label='dropout(50%)')\n",
        "        plt.text(0, -1.2, 'overfitting test loss=%.4f' % loss_fn(test_pred_simple, test_y).data.item(), fontdict={'size': 16, 'color':  'red'})\n",
        "        plt.text(0, -1.5, 'dropout test loss=%.4f' % loss_fn(test_pred_dropout, test_y).data.item(), fontdict={'size': 16, 'color': 'blue'})\n",
        "        plt.legend(loc='upper left'); plt.ylim((-2.5, 2.5));plt.pause(0.1)\n",
        "\n",
        "        # change back to train mode\n",
        "        net_simple.train()\n",
        "        net_dropout.train()\n",
        "        plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Vc8m9b9ox73Q",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}